Skip to content

Commit

Permalink
[SPARK-14313][ML][SPARKR] AFTSurvivalRegression model persistence in SparkR
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?
```AFTSurvivalRegressionModel``` supports ```save/load``` in SparkR.

## How was this patch tested?
Unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12685 from yanboliang/spark-14313.
  • Loading branch information
yanboliang authored and mengxr committed Apr 26, 2016
1 parent 162cf02 commit 92f6633
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 3 deletions.
27 changes: 27 additions & 0 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,31 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
invisible(callJMethod(writer, "save", path))
})

#' Save the AFT survival regression model to the input path.
#'
#' @param object A fitted AFT survival regression model.
#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite an existing model at \code{path}.
#'                  Defaults to FALSE, in which case an error is raised if
#'                  the output path already exists.
#'
#' @rdname ml.save
#' @name ml.save
#' @export
#' @examples
#' \dontrun{
#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
#' path <- "path/to/model"
#' ml.save(model, path)
#' }
setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            # Obtain the JVM-side MLWriter for the wrapped Scala model.
            jWriter <- callJMethod(object@jobj, "write")
            # Opt in to clobbering an existing directory only when requested.
            if (overwrite) {
              jWriter <- callJMethod(jWriter, "overwrite")
            }
            invisible(callJMethod(jWriter, "save", path))
          })

#' Load a fitted MLlib model from the input path.
#'
#' @param path Path of the model to read.
Expand All @@ -381,6 +406,8 @@ ml.load <- function(path) {
jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
return(new("NaiveBayesModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
return(new("AFTSurvivalRegressionModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
Expand Down
13 changes: 13 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,19 @@ test_that("survreg", {
expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
2.390146, 2.891269, 2.891269), tolerance = 1e-4)

# Test model save/load
modelPath <- tempfile(pattern = "survreg", fileext = ".tmp")
ml.save(model, modelPath)
expect_error(ml.save(model, modelPath))
ml.save(model, modelPath, overwrite = TRUE)
model2 <- ml.load(modelPath)
stats2 <- summary(model2)
coefs2 <- as.vector(stats2$coefficients[, 1])
expect_equal(coefs, coefs2)
expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients))

unlink(modelPath)

# Test survival::survreg
if (requireNamespace("survival", quietly = TRUE)) {
rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@

package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class AFTSurvivalRegressionWrapper private (
pipeline: PipelineModel,
features: Array[String]) {
val pipeline: PipelineModel,
val features: Array[String]) extends MLWritable {

private val aftModel: AFTSurvivalRegressionModel =
pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]
Expand All @@ -46,9 +53,12 @@ private[r] class AFTSurvivalRegressionWrapper private (
  // Applies the fitted pipeline to `dataset` and drops the intermediate
  // features column (the pipeline's assembled feature vector) so callers
  // receive the transformed output without that internal column.
  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset).drop(aftModel.getFeaturesCol)
  }

  // Exposes an MLWriter so SparkR's ml.save can persist this wrapper
  // (metadata + pipeline) via AFTSurvivalRegressionWrapperWriter.
  override def write: MLWriter =
    new AFTSurvivalRegressionWrapper.AFTSurvivalRegressionWrapperWriter(this)
}

private[r] object AFTSurvivalRegressionWrapper {
private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalRegressionWrapper] {

private def formulaRewrite(formula: String): (String, String) = {
var rewritedFormula: String = null
Expand Down Expand Up @@ -96,4 +106,40 @@ private[r] object AFTSurvivalRegressionWrapper {

new AFTSurvivalRegressionWrapper(pipeline, features)
}

  // Reader used by MLReadable.load to reconstruct a saved wrapper.
  override def read: MLReader[AFTSurvivalRegressionWrapper] = new AFTSurvivalRegressionWrapperReader

  // Overridden only to narrow the static return type; delegates to MLReadable.load.
  override def load(path: String): AFTSurvivalRegressionWrapper = super.load(path)

class AFTSurvivalRegressionWrapperWriter(instance: AFTSurvivalRegressionWrapper)
extends MLWriter {

override protected def saveImpl(path: String): Unit = {
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString

val rMetadata = ("class" -> instance.getClass.getName) ~
("features" -> instance.features.toSeq)
val rMetadataJson: String = compact(render(rMetadata))
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

instance.pipeline.save(pipelinePath)
}
}

class AFTSurvivalRegressionWrapperReader extends MLReader[AFTSurvivalRegressionWrapper] {

override def load(path: String): AFTSurvivalRegressionWrapper = {
implicit val format = DefaultFormats
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString

val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
val rMetadata = parse(rMetadataStr)
val features = (rMetadata \ "features").extract[Array[String]]

val pipeline = PipelineModel.load(pipelinePath)
new AFTSurvivalRegressionWrapper(pipeline, features)
}
}
}
2 changes: 2 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ private[r] object RWrappers extends MLReader[Object] {
val className = (rMetadata \ "class").extract[String]
className match {
case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
AFTSurvivalRegressionWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR ml.load does not support load $className")
}
Expand Down

0 comments on commit 92f6633

Please sign in to comment.