Skip to content

Commit

Permalink
[SPARK-11587][SPARKR] Fix the summary generic to match base R
Browse files Browse the repository at this point in the history
The signature is summary(object, ...) as defined in
https://stat.ethz.ch/R-manual/R-devel/library/base/html/summary.html

Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu>

Closes #9582 from shivaram/summary-fix.
  • Loading branch information
shivaram committed Nov 10, 2015
1 parent 1431319 commit c4e19b3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
6 changes: 3 additions & 3 deletions R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -1944,9 +1944,9 @@ setMethod("describe",
#' @rdname summary
#' @name summary
setMethod("summary",
signature(x = "DataFrame"),
function(x) {
describe(x)
signature(object = "DataFrame"),
function(object, ...) {
describe(object)
})


Expand Down
2 changes: 1 addition & 1 deletion R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ setGeneric("summarize", function(x,...) { standardGeneric("summarize") })

#' @rdname summary
#' @export
setGeneric("summary", function(x, ...) { standardGeneric("summary") })
setGeneric("summary", function(object, ...) { standardGeneric("summary") })

# @rdname tojson
# @export
Expand Down
12 changes: 6 additions & 6 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,17 @@ setMethod("predict", signature(object = "PipelineModel"),
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
setMethod("summary", signature(x = "PipelineModel"),
function(x, ...) {
setMethod("summary", signature(object = "PipelineModel"),
function(object, ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelName", x@model)
"getModelName", object@model)
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelFeatures", x@model)
"getModelFeatures", object@model)
coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelCoefficients", x@model)
"getModelCoefficients", object@model)
if (modelName == "LinearRegressionModel") {
devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelDevianceResiduals", x@model)
"getModelDevianceResiduals", object@model)
devianceResiduals <- matrix(devianceResiduals, nrow = 1)
colnames(devianceResiduals) <- c("Min", "Max")
rownames(devianceResiduals) <- rep("", times = 1)
Expand Down
6 changes: 6 additions & 0 deletions R/pkg/inst/tests/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,9 @@ test_that("summary coefficients match with native glm of family 'binomial'", {
rownames(stats$Coefficients) ==
c("(Intercept)", "Sepal_Length", "Sepal_Width")))
})

test_that("summary works on base GLM models", {
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
})

0 comments on commit c4e19b3

Please sign in to comment.