[SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel
See JIRA issue [SPARK-5988](https://issues.apache.org/jira/browse/SPARK-5988).

Author: Xusen Yin <yinxusen@gmail.com>

Closes #5450 from yinxusen/SPARK-5988 and squashes the following commits:

cb1ecfa [Xusen Yin] change Assignment into case class
b1dd24c [Xusen Yin] add test suite
63c3923 [Xusen Yin] add save load for power iteration clustering
yinxusen authored and mengxr committed Apr 13, 2015
1 parent 6cc5b3e commit 1e340c3
Showing 2 changed files with 97 additions and 5 deletions.
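In short, the commit gives `PowerIterationClusteringModel` the standard MLlib `Saveable`/`Loader` round trip. A minimal usage sketch, assuming a running `SparkContext` named `sc`; the similarity triples and output path below are hypothetical, not from the diff:

```scala
import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel}

// Hypothetical affinity graph: (srcId, dstId, similarity) triples.
val similarities = sc.parallelize(Seq(
  (0L, 1L, 1.0), (1L, 2L, 1.0), (2L, 3L, 0.1)))

val model = new PowerIterationClustering()
  .setK(2)
  .setMaxIterations(20)
  .run(similarities)

// New in this commit: persist the model and read it back.
model.save(sc, "/tmp/pic-model")  // hypothetical path
val sameModel = PowerIterationClusteringModel.load(sc, "/tmp/pic-model")
sameModel.assignments.collect().foreach(a => println(s"${a.id} -> ${a.cluster}"))
```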
mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala

```diff
@@ -17,15 +17,20 @@
 package org.apache.spark.mllib.clustering
 
-import org.apache.spark.{Logging, SparkException}
+import org.json4s.JsonDSL._
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.{Logging, SparkContext, SparkException}
 
 /**
  * :: Experimental ::
@@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom
 @Experimental
 class PowerIterationClusteringModel(
     val k: Int,
-    val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable
+    val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable {
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {
+  override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
+    PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private[clustering]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel"
+
+    def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val dataRDD = model.assignments.toDF()
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+
+      val k = (metadata \ "k").extract[Int]
+      val assignments = sqlContext.parquetFile(Loader.dataPath(path))
+      Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)
+
+      val assignmentsRDD = assignments.map {
+        case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster)
+      }
+
+      new PowerIterationClusteringModel(k, assignmentsRDD)
+    }
+  }
+}
+
 /**
  * :: Experimental ::
@@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] (
     val v = powerIter(w, maxIterations)
     val assignments = kMeans(v, k).mapPartitions({ iter =>
       iter.map { case (id, cluster) =>
-        new Assignment(id, cluster)
+        Assignment(id, cluster)
       }
     }, preservesPartitioning = true)
     new PowerIterationClusteringModel(k, assignments)
@@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging {
    * @param cluster assigned cluster id
    */
   @Experimental
-  class Assignment(val id: Long, val cluster: Int) extends Serializable
+  case class Assignment(id: Long, cluster: Int)
 
   /**
    * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
```
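As in other MLlib V1.0 exporters, `save()` writes a one-line JSON metadata file under `Loader.metadataPath(path)` and the assignments as Parquet under `Loader.dataPath(path)`. A sketch of inspecting the saved metadata, assuming a model already saved at `path` and that `metadataPath` resolves to `<path>/metadata` (MLlib's convention, not spelled out in this diff):

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods._

// Assumption: metadata lives at <path>/metadata, written by saveAsTextFile above.
val metadataJson = sc.textFile(path + "/metadata").first()
// Expected shape for this model, per SaveLoadV1_0.save:
// {"class":"org.apache.spark.mllib.clustering.PowerIterationClusteringModel","version":"1.0","k":3}

implicit val formats = DefaultFormats
val k = (parse(metadataJson) \ "k").extract[Int]
println(s"saved model has k = $k")
```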
mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala

```diff
@@ -18,12 +18,15 @@
 package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable
+import scala.util.Random
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkContext
 import org.apache.spark.graphx.{Edge, Graph}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
       assert(x ~== u1(i.toInt) absTol 1e-14)
     }
   }
+
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    val model = PowerIterationClusteringSuite.createModel(sc, 3, 10)
+    try {
+      model.save(sc, path)
+      val sameModel = PowerIterationClusteringModel.load(sc, path)
+      PowerIterationClusteringSuite.checkEqual(model, sameModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
+
+object PowerIterationClusteringSuite extends FunSuite {
+  def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
+    val assignments = sc.parallelize(
+      (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
+    new PowerIterationClusteringModel(k, assignments)
+  }
+
+  def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = {
+    assert(a.k === b.k)
+
+    val aAssignments = a.assignments.map(x => (x.id, x.cluster))
+    val bAssignments = b.assignments.map(x => (x.id, x.cluster))
+    val unequalElements = aAssignments.join(bAssignments).filter {
+      case (id, (c1, c2)) => c1 != c2 }.count()
+    assert(unequalElements === 0L)
+  }
+}
```
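One detail worth noting: the cb1ecfa change turning `Assignment` into a case class is what makes the Parquet path work, since Spark SQL infers a DataFrame schema from a case class by reflection (the old plain class was not a `Product`, so `toDF()` had nothing to work with), and `Row`s pattern-match back cleanly on load. A spark-shell sketch of that round trip, assuming Spark 1.3-era APIs and a hypothetical output path:

```scala
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.mllib.clustering.PowerIterationClustering.Assignment

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// toDF() works because Assignment is now a case class (a Product with typed fields).
val df = sc.parallelize(Seq(Assignment(0L, 1), Assignment(1L, 0))).toDF()
df.saveAsParquetFile("/tmp/assignments")  // hypothetical path

// Reading back: each Row pattern-matches into the case class, as in SaveLoadV1_0.load.
val restored = sqlContext.parquetFile("/tmp/assignments").map {
  case Row(id: Long, cluster: Int) => Assignment(id, cluster)
}
```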
