Skip to content

Commit

Permalink
store each interval as a record
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Apr 4, 2015
1 parent 2b2f5a1 commit 429ff7d
Showing 1 changed file with 17 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.util.Arrays.binarySearch

import scala.collection.mutable.ArrayBuffer

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

Expand Down Expand Up @@ -132,8 +133,9 @@ class IsotonicRegressionModel (
}

override def save(sc: SparkContext, path: String): Unit = {
val data = IsotonicRegressionModel.SaveLoadV1_0.Data(boundaries, predictions, isotonic)
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, data)
val intervals = boundaries.toList.zip(predictions.toList).toArray
val data = IsotonicRegressionModel.SaveLoadV1_0.Data(intervals)
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, data, isotonic)
}

override protected def formatVersion: String = "1.0"
Expand All @@ -151,43 +153,45 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

/** Model data for model import/export */
case class Data(boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean)
case class Data(intervals: Array[(Double, Double)])

def save(sc: SparkContext, path: String, data: Data): Unit = {
def save(sc: SparkContext, path: String, data: Data, isotonic: Boolean): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("isotonic" -> isotonic)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
dataRDD.saveAsParquetFile(dataPath(path))
}

def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
val sqlContext = new SQLContext(sc)
val dataRDD = sqlContext.parquetFile(dataPath(path))

checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("boundaries", "predictions", "isotonic").take(1)
val dataArray = dataRDD.select("intervals").take(1)
assert(dataArray.size == 1,
s"Unable to load IsotonicRegressionModel data from: ${dataPath(path)}")
val data = dataArray(0)
val boundaries = data.getAs[Seq[Double]](0).toArray
val predictions = data.getAs[Seq[Double]](1).toArray
val isotonic = data.getAs[Boolean](2)
new IsotonicRegressionModel(boundaries, predictions, isotonic)
val intervals = data.getAs[Seq[(Double, Double)]](0)
val (boundaries, predictions) = intervals.unzip
(boundaries.toArray, predictions.toArray)
}
}

override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
implicit val formats = DefaultFormats
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val isotonic = (metadata \ "isotonic").extract[Boolean]
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val model = SaveLoadV1_0.load(sc, path)
model
val (boundaries, predictions) = SaveLoadV1_0.load(sc, path)
new IsotonicRegressionModel(boundaries, predictions, isotonic)
case _ => throw new Exception(
s"IsotonicRegressionModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
Expand Down

0 comments on commit 429ff7d

Please sign in to comment.