Skip to content

Commit

Permalink
Model import/export for IsotonicRegression
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Mar 30, 2015
1 parent 19d4c39 commit 2b2f5a1
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ import java.util.Arrays.binarySearch

import scala.collection.mutable.ArrayBuffer

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

/**
* :: Experimental ::
Expand All @@ -42,7 +48,7 @@ import org.apache.spark.rdd.RDD
class IsotonicRegressionModel (
val boundaries: Array[Double],
val predictions: Array[Double],
val isotonic: Boolean) extends Serializable {
val isotonic: Boolean) extends Serializable with Saveable {

private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse

Expand Down Expand Up @@ -124,6 +130,71 @@ class IsotonicRegressionModel (
predictions(foundIndex)
}
}

/**
 * Persist this model under `path`.
 *
 * Packs the boundaries, predictions and isotonic flag into a format-1.0
 * `Data` record and hands it to the companion object's `SaveLoadV1_0.save`.
 *
 * @param sc   SparkContext passed through to the writer
 * @param path destination path for the saved model
 */
override def save(sc: SparkContext, path: String): Unit = {
  import IsotonicRegressionModel.SaveLoadV1_0
  SaveLoadV1_0.save(sc, path, SaveLoadV1_0.Data(boundaries, predictions, isotonic))
}

/**
 * Version string of the save/load format for this model.
 * NOTE(review): this literal duplicates `SaveLoadV1_0.thisFormatVersion` in the
 * companion object, which is what `save` writes into the metadata; keep them in sync.
 */
override protected def formatVersion: String = "1.0"
}

/**
 * Companion object implementing model import for [[IsotonicRegressionModel]],
 * plus the version-specific read/write routines used by the model's `save`.
 */
object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {

  import org.apache.spark.mllib.util.Loader._

  /** Read/write routines for save/load format version 1.0. */
  private object SaveLoadV1_0 {

    def thisFormatVersion: String = "1.0"

    /** Hard-code class name string in case it changes in the future */
    def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

    /** Model data for model import/export */
    case class Data(boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean)

    /**
     * Write the model to `path`: a JSON metadata record (class name and format
     * version) as a text file under `metadataPath(path)`, and the model data as a
     * single-row Parquet file under `dataPath(path)`.
     *
     * @param sc   SparkContext used to create the RDDs that are written out
     * @param path root path of the saved model
     * @param data model contents to persist
     */
    def save(sc: SparkContext, path: String, data: Data): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

      val dataFrame: DataFrame = sc.parallelize(Seq(data), 1).toDF()
      dataFrame.saveAsParquetFile(dataPath(path))
    }

    /**
     * Read the Parquet data under `dataPath(path)`, check its schema against
     * [[Data]], and rebuild the model from the first row.
     *
     * @param sc   SparkContext used to create the SQLContext for reading
     * @param path root path of the saved model
     * @return the reconstructed model
     */
    def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
      val sqlContext = new SQLContext(sc)
      val dataFrame = sqlContext.parquetFile(dataPath(path))

      checkSchema[Data](dataFrame.schema)
      // Select columns explicitly so the positional getAs calls below do not
      // depend on the column order stored in the file.
      val rows = dataFrame.select("boundaries", "predictions", "isotonic").take(1)
      assert(rows.size == 1,
        s"Unable to load IsotonicRegressionModel data from: ${dataPath(path)}")
      val row = rows(0)
      val boundaries = row.getAs[Seq[Double]](0).toArray
      val predictions = row.getAs[Seq[Double]](1).toArray
      val isotonic = row.getAs[Boolean](2)
      new IsotonicRegressionModel(boundaries, predictions, isotonic)
    }
  }

  /**
   * Load a model previously written by `IsotonicRegressionModel.save`.
   *
   * Reads the stored metadata and dispatches on (class name, format version);
   * only format 1.0 of this class is recognized.
   *
   * @param sc   SparkContext used for reading
   * @param path root path of the saved model
   * @return the loaded model
   * @throws Exception if the stored class name / version pair is not supported
   */
  override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
    // The third tuple element (the parsed metadata) is not needed for format 1.0.
    val (loadedClassName, version, _) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        SaveLoadV1_0.load(sc, path)
      case _ => throw new Exception(
        s"IsotonicRegressionModel.load did not recognize model with (className, format version):" +
        s"($loadedClassName, $version). Supported:\n" +
        s" ($classNameV1_0, 1.0)"
      )
    }
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.scalatest.{Matchers, FunSuite}

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {

Expand Down Expand Up @@ -73,6 +74,24 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
assert(model.isotonic)
}

test("model save/load") {
  val model = runIsotonicRegression(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18), true)

  val tempDir = Utils.createTempDir()
  val path = tempDir.toURI.toString

  // Save model, load it back, and compare every persisted field.
  try {
    model.save(sc, path)
    val sameModel = IsotonicRegressionModel.load(sc, path)
    assert(model.boundaries === sameModel.boundaries)
    assert(model.predictions === sameModel.predictions)
    // Compare against the reloaded model; comparing `model.isotonic` with itself
    // would always pass and never verify the round-trip of this field.
    assert(model.isotonic == sameModel.isotonic)
  } finally {
    // Always clean up the temporary directory, even if an assertion fails.
    Utils.deleteRecursively(tempDir)
  }
}

test("isotonic regression with size 0") {
val model = runIsotonicRegression(Seq(), true)

Expand Down

0 comments on commit 2b2f5a1

Please sign in to comment.