Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Apr 9, 2015
1 parent 429ff7d commit 49600cc
Showing 1 changed file with 14 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ class IsotonicRegressionModel (
}

override def save(sc: SparkContext, path: String): Unit = {
val intervals = boundaries.toList.zip(predictions.toList).toArray
val data = IsotonicRegressionModel.SaveLoadV1_0.Data(intervals)
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, data, isotonic)
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
}

override protected def formatVersion: String = "1.0"
Expand All @@ -153,9 +151,14 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

/** Model data for model import/export */
case class Data(intervals: Array[(Double, Double)])

def save(sc: SparkContext, path: String, data: Data, isotonic: Boolean): Unit = {
case class Data(boundary: Double, prediction: Double)

def save(
sc: SparkContext,
path: String,
boundaries: Array[Double],
predictions: Array[Double],
isotonic: Boolean): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

Expand All @@ -164,21 +167,18 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
("isotonic" -> isotonic)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
dataRDD.saveAsParquetFile(dataPath(path))
sqlContext.createDataFrame(boundaries.toList.zip(predictions.toList)
.map { case (b, p) => Data(b, p) }).saveAsParquetFile(dataPath(path))
}

def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
val sqlContext = new SQLContext(sc)
val dataRDD = sqlContext.parquetFile(dataPath(path))

checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("intervals").take(1)
assert(dataArray.size == 1,
s"Unable to load IsotonicRegressionModel data from: ${dataPath(path)}")
val data = dataArray(0)
val intervals = data.getAs[Seq[(Double, Double)]](0)
val (boundaries, predictions) = intervals.unzip
val dataArray = dataRDD.select("boundary", "prediction").collect()
val (boundaries, predictions) = dataArray.map {
x => (x.getAs[Double](0), x.getAs[Double](1)) }.toList.unzip
(boundaries.toArray, predictions.toArray)
}
}
Expand Down

0 comments on commit 49600cc

Please sign in to comment.