Fixed MLWritable and MLReadable to JavaMLWritable and JavaMLReadable
GayathriMurali committed Mar 23, 2016
2 parents 540fe8e + 5dfc019 commit bad5c3e
Showing 199 changed files with 6,328 additions and 2,191 deletions.
3 changes: 2 additions & 1 deletion R/pkg/DESCRIPTION
@@ -11,7 +11,8 @@ Depends:
R (>= 3.0),
methods,
Suggests:
testthat
testthat,
e1071
Description: R frontend for Spark
License: Apache License (== 2.0)
Collate:
3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
@@ -15,7 +15,8 @@ exportMethods("glm",
"predict",
"summary",
"kmeans",
"fitted")
"fitted",
"naiveBayes")

# Job group lifecycle management methods
export("setJobGroup",
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
@@ -1175,3 +1175,7 @@ setGeneric("kmeans")
#' @rdname fitted
#' @export
setGeneric("fitted")

#' @rdname naiveBayes
#' @export
setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
91 changes: 86 additions & 5 deletions R/pkg/R/mllib.R
@@ -22,6 +22,11 @@
#' @export
setClass("PipelineModel", representation(model = "jobj"))

#' @title S4 class that represents a NaiveBayesModel
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
#' @export
setClass("NaiveBayesModel", representation(jobj = "jobj"))

#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -42,7 +47,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' sc <- sparkR.init()
#' sqlContext <- sparkRSQL.init(sc)
#' data(iris)
@@ -71,7 +76,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
#' @rdname predict
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
@@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"),
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})

#' Make predictions from a naive Bayes model
#'
#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict.
#'
#' @param object A fitted naive Bayes model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted labels in a column named "prediction"
#' @rdname predict
#' @export
#' @examples
#' \dontrun{
#' model <- naiveBayes(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#'}
setMethod("predict", signature(object = "NaiveBayesModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
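
A short usage sketch for this method (assumes a fitted NaiveBayesModel m and a SparkR DataFrame df, mirroring test_mllib.R below): the result keeps the input columns and appends a "prediction" column of string labels.

p <- collect(select(predict(m, df), "prediction"))
head(p$prediction)   # e.g. "Yes" "Yes" "Yes" "Yes" "No" "No"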

#' Get the summary of a model
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
@@ -97,7 +122,7 @@ setMethod("predict", signature(object = "PipelineModel"),
#' @rdname summary
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
@@ -140,6 +165,35 @@ setMethod("summary", signature(object = "PipelineModel"),
}
})

#' Get the summary of a naive Bayes model
#'
#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
#'
#' @param object A fitted MLlib model
#' @return a list containing 'apriori', the label distribution, and 'tables', conditional
#'   probabilities given the target label
#' @rdname summary
#' @export
#' @examples
#' \dontrun{
#' model <- naiveBayes(y ~ x, trainingData)
#' summary(model)
#'}
setMethod("summary", signature(object = "NaiveBayesModel"),
function(object, ...) {
jobj <- object@jobj
features <- callJMethod(jobj, "features")
labels <- callJMethod(jobj, "labels")
apriori <- callJMethod(jobj, "apriori")
apriori <- t(as.matrix(unlist(apriori)))
colnames(apriori) <- unlist(labels)
tables <- callJMethod(jobj, "tables")
tables <- matrix(tables, nrow = length(labels))
rownames(tables) <- unlist(labels)
colnames(tables) <- unlist(features)
return(list(apriori = apriori, tables = tables))
})
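
The shapes follow from the code above: apriori is a 1 x #labels matrix of label priors, and tables is a #labels x #features matrix of conditional probabilities. A brief inspection sketch (assumes the Titanic model m fitted in test_mllib.R below; feature names such as "Age_Adult" come from the formula's encoding of categorical features):

s <- summary(m)
s$apriori                      # columns "No", "Yes"; entries sum to 1
s$tables["Yes", "Age_Adult"]   # P(Age = Adult | Survived = Yes)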

#' Fit a k-means model
#'
#' Fit a k-means model, similarly to R's kmeans().
@@ -152,7 +206,7 @@ setMethod("summary", signature(object = "PipelineModel"),
#' @rdname kmeans
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
#'}
setMethod("kmeans", signature(x = "DataFrame"),
@@ -173,7 +227,7 @@ setMethod("kmeans", signature(x = "DataFrame"),
#' @rdname fitted
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' model <- kmeans(trainingData, 2)
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
@@ -192,3 +246,30 @@ setMethod("fitted", signature(object = "PipelineModel"),
stop(paste("Unsupported model", modelName, sep = " "))
}
})

#' Fit a Bernoulli naive Bayes model
#'
#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes(), except that
#' only categorical features are supported. The input should be a DataFrame of observations
#' instead of a contingency table.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param data DataFrame for training
#' @param laplace Smoothing parameter
#' @return a fitted naive Bayes model
#' @rdname naiveBayes
#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(sqlContext, infert)
#' model <- naiveBayes(education ~ ., df, laplace = 0)
#'}
setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
function(formula, data, laplace = 0, ...) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
formula, data@sdf, laplace)
return(new("NaiveBayesModel", jobj = jobj))
})
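
An end-to-end sketch tying naiveBayes(), summary(), and predict() together (assumes a running SparkR session with a sqlContext; mirrors the \dontrun example above):

df <- createDataFrame(sqlContext, infert)
model <- naiveBayes(education ~ ., df, laplace = 0)
summary(model)                # label priors and conditional probabilities
showDF(predict(model, df))    # input columns plus a "prediction" column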
59 changes: 59 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
@@ -141,3 +141,62 @@ test_that("kmeans", {
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
})

test_that("naiveBayes", {
# R code to reproduce the result.
# We do not support instance weights yet, so we ignore the frequencies.
#
#' library(e1071)
#' t <- as.data.frame(Titanic)
#' t1 <- t[t$Freq > 0, -5]
#' m <- naiveBayes(Survived ~ ., data = t1)
#' m
#' predict(m, t1)
#
# -- output of 'm'
#
# A-priori probabilities:
# Y
# No Yes
# 0.4166667 0.5833333
#
# Conditional probabilities:
# Class
# Y 1st 2nd 3rd Crew
# No 0.2000000 0.2000000 0.4000000 0.2000000
# Yes 0.2857143 0.2857143 0.2857143 0.1428571
#
# Sex
# Y Male Female
# No 0.5 0.5
# Yes 0.5 0.5
#
# Age
# Y Child Adult
# No 0.2000000 0.8000000
# Yes 0.4285714 0.5714286
#
# -- output of 'predict(m, t1)'
#
# Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No
#

t <- as.data.frame(Titanic)
t1 <- t[t$Freq > 0, -5]
df <- suppressWarnings(createDataFrame(sqlContext, t1))
m <- naiveBayes(Survived ~ ., data = df)
s <- summary(m)
expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
expect_equal(sum(s$apriori), 1)
expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
p <- collect(select(predict(m, df), "prediction"))
expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
"Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
"Yes", "Yes", "No", "No"))

# Test e1071::naiveBayes
if (requireNamespace("e1071", quietly = TRUE)) {
expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
}
})
3 changes: 2 additions & 1 deletion R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1817,7 +1817,8 @@ test_that("approxQuantile() on a DataFrame", {

test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
expect_equal(grepl("Table not found: blah", retError), TRUE)
expect_equal(grepl("Table not found", retError), TRUE)
expect_equal(grepl("blah", retError), TRUE)
})

irisDF <- suppressWarnings(createDataFrame(sqlContext, iris))
@@ -32,6 +32,7 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -163,19 +164,22 @@ public final class BytesToBytesMap extends MemoryConsumer {
private long peakMemoryUsedBytes = 0L;

private final BlockManager blockManager;
private final SerializerManager serializerManager;
private volatile MapIterator destructiveIterator = null;
private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();

public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
SerializerManager serializerManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
boolean enablePerfMetrics) {
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
this.serializerManager = serializerManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
@@ -209,6 +213,7 @@ public BytesToBytesMap(
this(
taskMemoryManager,
SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null,
initialCapacity,
0.70,
pageSizeBytes,
@@ -271,7 +276,7 @@ private void advanceToNextPage() {
}
try {
Closeables.close(reader, /* swallowIOException = */ false);
reader = spillWriters.getFirst().getReader(blockManager);
reader = spillWriters.getFirst().getReader(serializerManager);
recordsInPage = -1;
} catch (IOException e) {
// Scala iterator does not handle exception
@@ -31,6 +31,7 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
@@ -51,6 +52,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
private final RecordComparator recordComparator;
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final SerializerManager serializerManager;
private final TaskContext taskContext;
private ShuffleWriteMetrics writeMetrics;

@@ -78,14 +80,16 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
pageSizeBytes, inMemorySorter);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
sorter.inMemSorter = null;
@@ -95,18 +99,20 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
public static UnsafeExternalSorter create(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager,
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
}

private UnsafeExternalSorter(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
@@ -116,6 +122,7 @@ private UnsafeExternalSorter(
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
this.serializerManager = serializerManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
this.prefixComparator = prefixComparator;
@@ -412,7 +419,7 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
}
if (inMemSorter != null) {
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -463,7 +470,7 @@ public long spill() throws IOException {
}
spillWriter.close();
spillWriters.add(spillWriter);
nextUpstream = spillWriter.getReader(blockManager);
nextUpstream = spillWriter.getReader(serializerManager);

long released = 0L;
synchronized (UnsafeExternalSorter.this) {
@@ -549,7 +556,7 @@ public UnsafeSorterIterator getIterator() throws IOException {
} else {
LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
queue.add(spillWriter.getReader(blockManager));
queue.add(spillWriter.getReader(serializerManager));
}
if (inMemSorter != null) {
queue.add(inMemSorter.getSortedIterator());
