From 49600cc3be22383d787d40a73cae09ed7106c083 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Thu, 9 Apr 2015 13:58:39 -0400
Subject: [PATCH] address comments

---
 .../mllib/regression/IsotonicRegression.scala | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 2e51834f3ba9a..2144c19f12c0f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -133,9 +133,7 @@ class IsotonicRegressionModel (
   }
 
   override def save(sc: SparkContext, path: String): Unit = {
-    val intervals = boundaries.toList.zip(predictions.toList).toArray
-    val data = IsotonicRegressionModel.SaveLoadV1_0.Data(intervals)
-    IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, data, isotonic)
+    IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
   }
 
   override protected def formatVersion: String = "1.0"
@@ -153,9 +151,14 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
     def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"
 
     /** Model data for model import/export */
-    case class Data(intervals: Array[(Double, Double)])
-
-    def save(sc: SparkContext, path: String, data: Data, isotonic: Boolean): Unit = {
+    case class Data(boundary: Double, prediction: Double)
+
+    def save(
+        sc: SparkContext,
+        path: String,
+        boundaries: Array[Double],
+        predictions: Array[Double],
+        isotonic: Boolean): Unit = {
       val sqlContext = new SQLContext(sc)
       import sqlContext.implicits._
 
@@ -164,8 +167,8 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
         ("isotonic" -> isotonic)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
 
-      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
-      dataRDD.saveAsParquetFile(dataPath(path))
+      sqlContext.createDataFrame(boundaries.toList.zip(predictions.toList)
+        .map { case (b, p) => Data(b, p) }).saveAsParquetFile(dataPath(path))
     }
 
     def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
@@ -173,12 +176,9 @@
 
       val dataRDD = sqlContext.parquetFile(dataPath(path))
       checkSchema[Data](dataRDD.schema)
-      val dataArray = dataRDD.select("intervals").take(1)
-      assert(dataArray.size == 1,
-        s"Unable to load IsotonicRegressionModel data from: ${dataPath(path)}")
-      val data = dataArray(0)
-      val intervals = data.getAs[Seq[(Double, Double)]](0)
-      val (boundaries, predictions) = intervals.unzip
+      val dataArray = dataRDD.select("boundary", "prediction").collect()
+      val (boundaries, predictions) = dataArray.map {
+        x => (x.getAs[Double](0), x.getAs[Double](1)) }.toList.unzip
       (boundaries.toArray, predictions.toArray)
     }
   }
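
For reference, a minimal save/load round trip against the public API this patch
touches. This is a sketch only: the example object name, the model values, the
local master, and the /tmp path are illustrative and not part of the patch.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.regression.IsotonicRegressionModel

    object IsotonicSaveLoadExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("isotonic-save-load").setMaster("local"))

        // Illustrative model: three (boundary, prediction) pairs; isotonic = true
        // means the predictions are non-decreasing in the boundaries.
        val model = new IsotonicRegressionModel(
          Array(0.0, 1.0, 2.0), Array(0.5, 1.5, 2.5), isotonic = true)

        // With this patch, save() writes one Data(boundary, prediction) row per
        // pair to Parquet, instead of a single row holding Array[(Double, Double)].
        val path = "/tmp/isotonicModel"
        model.save(sc, path)

        // load() selects the "boundary" and "prediction" columns, collects the
        // rows, and rebuilds the two parallel arrays.
        val sameModel = IsotonicRegressionModel.load(sc, path)
        assert(sameModel.predict(1.0) == model.predict(1.0))

        sc.stop()
      }
    }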