From 429ff7dd71384fcb36f355930611673ae8c84e71 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Sat, 4 Apr 2015 13:47:56 -0400
Subject: [PATCH] store each interval as a record

---
 .../mllib/regression/IsotonicRegression.scala | 30 +++++++++++--------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index ce35c6f5558e2..2e51834f3ba9a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -23,6 +23,7 @@ import java.util.Arrays.binarySearch
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
@@ -132,8 +133,9 @@ class IsotonicRegressionModel (
   }
 
   override def save(sc: SparkContext, path: String): Unit = {
-    val data = IsotonicRegressionModel.SaveLoadV1_0.Data(boundaries, predictions, isotonic)
-    IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, data)
+    val intervals = boundaries.toList.zip(predictions.toList).toArray
+    val data = IsotonicRegressionModel.SaveLoadV1_0.Data(intervals)
+    IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, data, isotonic)
   }
 
   override protected def formatVersion: String = "1.0"
@@ -151,43 +153,45 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
     def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"
 
     /** Model data for model import/export */
-    case class Data(boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean)
+    case class Data(intervals: Array[(Double, Double)])
 
-    def save(sc: SparkContext, path: String, data: Data): Unit = {
+    def save(sc: SparkContext, path: String, data: Data, isotonic: Boolean): Unit = {
       val sqlContext = new SQLContext(sc)
       import sqlContext.implicits._
 
       val metadata = compact(render(
-        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+          ("isotonic" -> isotonic)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
 
       val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
       dataRDD.saveAsParquetFile(dataPath(path))
     }
 
-    def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
+    def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
       val sqlContext = new SQLContext(sc)
       val dataRDD = sqlContext.parquetFile(dataPath(path))
 
       checkSchema[Data](dataRDD.schema)
-      val dataArray = dataRDD.select("boundaries", "predictions", "isotonic").take(1)
+      val dataArray = dataRDD.select("intervals").take(1)
       assert(dataArray.size == 1,
         s"Unable to load IsotonicRegressionModel data from: ${dataPath(path)}")
       val data = dataArray(0)
-      val boundaries = data.getAs[Seq[Double]](0).toArray
-      val predictions = data.getAs[Seq[Double]](1).toArray
-      val isotonic = data.getAs[Boolean](2)
-      new IsotonicRegressionModel(boundaries, predictions, isotonic)
+      val intervals = data.getAs[Seq[(Double, Double)]](0)
+      val (boundaries, predictions) = intervals.unzip
+      (boundaries.toArray, predictions.toArray)
     }
   }
 
   override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
+    implicit val formats = DefaultFormats
     val (loadedClassName, version, metadata) = loadMetadata(sc, path)
+    val isotonic = (metadata \ "isotonic").extract[Boolean]
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val model = SaveLoadV1_0.load(sc, path)
-        model
+        val (boundaries, predictions) = SaveLoadV1_0.load(sc, path)
+        new IsotonicRegressionModel(boundaries, predictions, isotonic)
       case _ => throw new Exception(
         s"IsotonicRegressionModel.load did not recognize model with (className, format version):" +
         s"($loadedClassName, $version). Supported:\n" +
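
Note (not part of the commit): a minimal round-trip sketch of the changed save/load
path, assuming a Spark MLlib build that already contains this patch and a local
SparkContext; the boundary/prediction values and the /tmp output path are
illustrative only and do not come from the patch.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.regression.IsotonicRegressionModel

    object IsotonicSaveLoadExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("isotonic-save-load").setMaster("local[2]"))

        // With this patch, each (boundary, prediction) pair is written as one record
        // in the Parquet data file, while the `isotonic` flag moves into the JSON metadata.
        val boundaries = Array(0.0, 1.0, 3.0, 5.0)
        val predictions = Array(0.5, 1.5, 2.0, 4.0)
        val model = new IsotonicRegressionModel(boundaries, predictions, isotonic = true)

        val path = "/tmp/isotonicRegressionModel"  // hypothetical output location
        model.save(sc, path)
        val sameModel = IsotonicRegressionModel.load(sc, path)

        // The loaded model should reproduce the saved boundaries and predictions.
        assert(sameModel.boundaries.sameElements(model.boundaries))
        assert(sameModel.predictions.sameElements(model.predictions))
        println(sameModel.predict(2.0))

        sc.stop()
      }
    }

Storing the intervals as individual records, rather than two parallel arrays plus a
boolean in a single row, gives the Parquet data file a flat per-interval schema, and
keeping `isotonic` in the JSON metadata leaves the data file purely numeric; that
appears to be the intent behind the commit title "store each interval as a record".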