[SPARK-30820][SPARKR][ML] Add FMClassifier to SparkR #27570

Closed
wants to merge 9 commits into from
3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
@@ -68,7 +68,8 @@ exportMethods("glm",
"spark.freqItemsets",
"spark.associationRules",
"spark.findFrequentSequentialPatterns",
"spark.assignClusters")
"spark.assignClusters",
"spark.fmClassifier")

# Job group lifecycle management methods
export("setJobGroup",
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
@@ -1471,6 +1471,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
setGeneric("spark.bisectingKmeans",
function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") })

#' @rdname spark.fmClassifier
setGeneric("spark.fmClassifier",
function(data, formula, ...) { standardGeneric("spark.fmClassifier") })

#' @rdname spark.gaussianMixture
setGeneric("spark.gaussianMixture",
function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
157 changes: 157 additions & 0 deletions R/pkg/R/mllib_classification.R
@@ -42,6 +42,12 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
#' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj"))

#' S4 class that represents a FMClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala FMClassifierWrapper
#' @note FMClassificationModel since 3.1.0
setClass("FMClassificationModel", representation(jobj = "jobj"))

#' Linear SVM Model
#'
#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package.
@@ -649,3 +655,154 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

#' Factorization Machines Classification Model
#'
#' \code{spark.fmClassifier} fits a factorization machines classification model against a SparkDataFrame.
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#' Only categorical data is supported.
#'
#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param factorSize dimensionality of the factors.
#' @param fitLinear whether to fit linear term. # TODO Can we express this with formula?
Contributor

Have you checked this TODO yet?

Member Author

I think it's more a point for discussion. Adding custom formula components is not very hard; the question is whether it makes sense to complicate things for this.

#' @param regParam the regularization parameter.
#' @param miniBatchFraction the mini-batch fraction parameter.
#' @param initStd the standard deviation of initial coefficients.
#' @param maxIter maximum iteration number.
#' @param stepSize stepSize parameter.
#' @param tol convergence tolerance of iterations.
#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "adamW".
#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of
#' class label 1 is > threshold, then predict 1, else 0. A high threshold
#' encourages the model to predict 0 more often; a low threshold encourages the
#' model to predict 1 more often. Note: Setting this with threshold p is
#' equivalent to setting thresholds c(1-p, p).
#' @param seed seed parameter for weights initialization.
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and
#' label column of string type.
#' Supported options: "skip" (filter out rows with invalid data),
#' "error" (throw an error), "keep" (put invalid data in
#' a special additional bucket, at index numLabels). Default
#' is "error".
#' @param ... additional arguments passed to the method.
#' @return \code{spark.fmClassifier} returns a fitted Factorization Machines Classification Model.
#' @rdname spark.fmClassifier
#' @aliases spark.fmClassifier,SparkDataFrame,formula-method
#' @name spark.fmClassifier
#' @seealso \link{read.ml}
#' @examples
#' \dontrun{
#' df <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm")
#'
#' # fit Factorization Machines Classification Model
#' model <- spark.fmClassifier(
#' df, label ~ features,
#' regParam = 0.01, maxIter = 10, fitLinear = TRUE
#' )
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#' }
#' @note spark.fmClassifier since 3.1.0
setMethod("spark.fmClassifier", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, factorSize = 8, fitLinear = TRUE, regParam = 0.0,
                   miniBatchFraction = 1.0, initStd = 0.01, maxIter = 100, stepSize = 1.0,
                   tol = 1e-6, solver = c("adamW", "gd"), thresholds = NULL, seed = NULL,
                   handleInvalid = c("error", "keep", "skip")) {

            formula <- paste(deparse(formula), collapse = "")

            if (!is.null(seed)) {
              seed <- as.character(as.integer(seed))
            }

            if (!is.null(thresholds)) {
              thresholds <- as.list(thresholds)
            }

            solver <- match.arg(solver)
            handleInvalid <- match.arg(handleInvalid)

            jobj <- callJStatic("org.apache.spark.ml.r.FMClassifierWrapper",
                                "fit",
                                data@sdf,
                                formula,
                                as.integer(factorSize),
                                as.logical(fitLinear),
                                as.numeric(regParam),
                                as.numeric(miniBatchFraction),
                                as.numeric(initStd),
                                as.integer(maxIter),
                                as.numeric(stepSize),
                                as.numeric(tol),
                                solver,
                                seed,
                                thresholds,
                                handleInvalid)
            new("FMClassificationModel", jobj = jobj)
          })
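
The argument handling that precedes the JVM call relies on a few base-R idioms, which can be sketched without Spark. The helper below, `normalize_fm_args`, is hypothetical and exists only for illustration; it mirrors the conversions in the method body above (choice vectors resolved with `match.arg`, the seed passed as a string, thresholds passed as a list).

```r
# Hypothetical helper for illustration only; it is not part of SparkR.
normalize_fm_args <- function(solver = c("adamW", "gd"),
                              handleInvalid = c("error", "keep", "skip"),
                              seed = NULL, thresholds = NULL) {
  # match.arg() returns the first choice when the argument is left at its
  # default vector, and raises an error for values outside the choices.
  solver <- match.arg(solver)
  handleInvalid <- match.arg(handleInvalid)

  # The seed is truncated to an integer and passed on as a string; NULL stays NULL.
  if (!is.null(seed)) {
    seed <- as.character(as.integer(seed))
  }

  # A numeric vector of thresholds becomes a list with one element per value.
  if (!is.null(thresholds)) {
    thresholds <- as.list(thresholds)
  }

  list(solver = solver, handleInvalid = handleInvalid,
       seed = seed, thresholds = thresholds)
}

defaults <- normalize_fm_args()
custom <- normalize_fm_args(solver = "gd", seed = 42.9, thresholds = c(0.3, 0.7))
bad_solver <- tryCatch(normalize_fm_args(solver = "lbfgs"),
                       error = function(e) "error")
```

With the defaults, `solver` resolves to `"adamW"` and `handleInvalid` to `"error"`, which matches the documented default for `handleInvalid`. An unsupported solver name fails in R before anything is sent to the JVM.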

# Returns the summary of a FM Classification model produced by \code{spark.fmClassifier}

#' @param object a FM Classification model fitted by \code{spark.fmClassifier}.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#' @rdname spark.fmClassifier
#' @note summary(FMClassificationModel) since 3.1.0
setMethod("summary", signature(object = "FMClassificationModel"),
          function(object) {
            jobj <- object@jobj
            features <- callJMethod(jobj, "rFeatures")
            coefficients <- callJMethod(jobj, "rCoefficients")
            coefficients <- as.matrix(unlist(coefficients))
            colnames(coefficients) <- c("Estimate")
            rownames(coefficients) <- unlist(features)
            numClasses <- callJMethod(jobj, "numClasses")
            numFeatures <- callJMethod(jobj, "numFeatures")
            raw_factors <- unlist(callJMethod(jobj, "rFactors"))
            factor_size <- callJMethod(jobj, "factorSize")

            list(
              coefficients = coefficients,
              factors = matrix(raw_factors, ncol = factor_size),
              numClasses = numClasses, numFeatures = numFeatures,
              factorSize = factor_size
            )
          })
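
The `factors` entry of the summary is built by reshaping a flat vector with `matrix(..., ncol = factor_size)`, which fills column by column. A small base-R sketch with made-up numbers (six values and a factor size of 2 are assumptions chosen for illustration, not output from a fitted model):

```r
# Made-up flat vector standing in for the unlisted factors.
raw_factors <- c(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)
factor_size <- 2

# matrix() fills in column-major order, so the first
# length(raw_factors) / factor_size values form column 1.
factors <- matrix(raw_factors, ncol = factor_size)

dim(factors)   # number of rows is length(raw_factors) / factor_size
factors[, 1]   # first column holds the first three values
```

The number of columns always equals `factorSize`, which is what the test in this PR checks through `summary(model1)$factorSize`.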

# Predicted values based on an FMClassificationModel model

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns the predicted values based on a FM Classification model.
#' @rdname spark.fmClassifier
#' @aliases predict,FMClassificationModel,SparkDataFrame-method
#' @note predict(FMClassificationModel) since 3.1.0
setMethod("predict", signature(object = "FMClassificationModel"),
function(object, newData) {
predict_internal(object, newData)
})

# Save fitted FMClassificationModel to the input path

#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE,
#' which means an exception is thrown if the output path exists.
#'
#' @rdname spark.fmClassifier
#' @aliases write.ml,FMClassificationModel,character-method
#' @note write.ml(FMClassificationModel, character) since 3.1.0
setMethod("write.ml", signature(object = "FMClassificationModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})
2 changes: 2 additions & 0 deletions R/pkg/R/mllib_utils.R
@@ -123,6 +123,8 @@ read.ml <- function(path) {
new("LinearSVCModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) {
new("FPGrowthModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FMClassifierWrapper")) {
new("FMClassificationModel", jobj = jobj)
} else {
stop("Unsupported model: ", jobj)
}
34 changes: 34 additions & 0 deletions R/pkg/tests/fulltests/test_mllib_classification.R
@@ -488,4 +488,38 @@ test_that("spark.naiveBayes", {
expect_equal(class(collect(predictions)$clicked[1]), "character")
})

test_that("spark.fmClassifier", {
df <- withColumn(
suppressWarnings(createDataFrame(iris)),
"Species", otherwise(when(column("Species") == "setosa", "Setosa"), "Not-Setosa")
)

model1 <- spark.fmClassifier(
df, Species ~ .,
regParam = 0.01, maxIter = 10, fitLinear = TRUE, factorSize = 3
)

prediction1 <- predict(model1, df)
expect_is(prediction1, "SparkDataFrame")
Contributor

Can we also check the predict result here?

Member Author

I am not sure such checks are really useful here. In practice, fitting is not a likely failure point; the most likely problems are related to parameter passing.

Contributor

I looked at the other classification tests. They seem to check the typeof and the result of the prediction. I guess it might be better to be consistent with the other tests?

Member Author

typeof is not applicable here. typeof is an S compatibility thingy and can only be used to distinguish between core types (here it could only determine whether the value is an S4 type).
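
A minimal base-R sketch of the distinction being drawn here, using a hypothetical toy S4 class (`ToyModel` is an assumption for illustration and is unrelated to the SparkR classes):

```r
library(methods)

# Hypothetical toy S4 class, used only to illustrate the point about typeof.
setClass("ToyModel", representation(id = "character"))
toy <- new("ToyModel", id = "m1")

core_type <- typeof(toy)               # only the core type of the object
class_name <- as.character(class(toy)) # the S4 class name
is_toy <- is(toy, "ToyModel")          # class membership test
```

For any S4 instance `typeof` gives the same core type regardless of the class, so it cannot tell one model class from another; a class-based check such as `is` can.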

Contributor

Seems to me that all the other ML R tests check the prediction result. For example, in LinearSVM,

  # Test prediction with string label
  prediction <- predict(model, training)
  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character")
  expected <- c("versicolor", "versicolor", "versicolor", "virginica",  "virginica",
                "virginica",  "virginica",  "virginica",  "virginica",  "virginica")
  expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected)

Is it OK if we do something similar here?

Member Author

Sure, we can. The question is what we are really trying to test in such cases. What types of implementation mistakes can we detect here that are not already covered by JVM tests and/or SparkR data frame tests?

These checks involve additional jobs, and many tests are already rejected to keep things manageable, so unless these serve a specific purpose, I'd prefer to keep things lean here.

In contrast, there are many SparkR ML failure modes that are real and could be tested, but are crippled by the lack of a required API. But that's way beyond the scope of this PR.

Contributor

I am OK with this.

expect_equal(summary(model1)$factorSize, 3)

# Test model save/load
if (windows_with_hadoop()) {
Member

Out of curiosity, why this check?

Member Author

This is used to avoid failures when winutils is missing. If I recall correctly, the primary target was CRAN tests (and these shouldn't run here anyway), but I think it's still applicable to AppVeyor.
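
A guard of this kind can be sketched in base R as follows. `can_run_fs_tests` is a hypothetical stand-in, and the use of `HADOOP_HOME` as a proxy for winutils is an assumption; the real `windows_with_hadoop()` helper in SparkR may be implemented differently.

```r
# Hypothetical stand-in for a guard like windows_with_hadoop(); the real
# SparkR helper may be implemented differently.
can_run_fs_tests <- function(os_type = .Platform$OS.type,
                             hadoop_home = Sys.getenv("HADOOP_HOME", unset = "")) {
  # Non-Windows platforms do not need winutils, so the tests can always run.
  # On Windows, require a non-empty HADOOP_HOME as a proxy for winutils.
  os_type != "windows" || nzchar(hadoop_home)
}

on_unix <- can_run_fs_tests("unix", "")
on_windows_without_hadoop <- can_run_fs_tests("windows", "")
on_windows_with_hadoop <- can_run_fs_tests("windows", "C:/hadoop")
```

Wrapping the save/load section in such a check skips only the file-system part of the test on an unprepared Windows machine, while the fitting and prediction checks still run.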

modelPath <- tempfile(pattern = "spark-fmclassifier", fileext = ".tmp")
write.ml(model1, modelPath)
model2 <- read.ml(modelPath)

expect_is(model2, "FMClassificationModel")

expect_equal(summary(model1), summary(model2))

prediction2 <- predict(model2, df)
expect_equal(
collect(drop(prediction1, c("rawPrediction", "probability"))),
collect(drop(prediction2, c("rawPrediction", "probability")))
)
unlink(modelPath)
}
})

sparkR.session.stop()
20 changes: 20 additions & 0 deletions R/pkg/vignettes/sparkr-vignettes.Rmd
@@ -523,6 +523,8 @@ SparkR supports the following machine learning models and algorithms.

* Naive Bayes

* Factorization Machines (FM) Classifier

#### Regression

* Accelerated Failure Time (AFT) Survival Model
@@ -705,6 +707,24 @@ naiveBayesPrediction <- predict(naiveBayesModel, titanicDF)
head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction"))
```

#### Factorization Machines Classifier

Factorization Machines for classification problems.

For background and details about the implementation of factorization machines,
refer to the [Factorization Machines section](https://spark.apache.org/docs/latest/ml-classification-regression.html#factorization-machines).

```{r}
t <- as.data.frame(Titanic)
training <- createDataFrame(t)

model <- spark.fmClassifier(training, Survived ~ Age + Sex)
summary(model)

predictions <- predict(model, training)
head(select(predictions, predictions$prediction))
```

#### Accelerated Failure Time Survival Model

Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring.